import os
import pandas as pd
import numpy as np
import glob
import re

def calculate_relative_improvement(lru_value, tlru_value):
    """Return T-LRU's percentage reduction relative to a baseline.

    Computed as ``(baseline - T-LRU) / baseline * 100``, so positive values mean
    T-LRU is better and negative values mean it is worse.

    Args:
        lru_value: Baseline count (e.g. vanilla or threshold LRU).
        tlru_value: T-LRU count for the same configuration.

    Returns:
        The relative improvement in percent, or 0.0 when the baseline is not
        strictly positive (zero, negative or NaN), where the ratio is undefined.
    """
    # `not x > 0` (rather than `x <= 0`) so that a NaN baseline also lands here.
    if not lru_value > 0:
        return 0.0
    reduction = lru_value - tlru_value
    return reduction / lru_value * 100

def get_xi_from_directory(dir_path):
    """Parse the xi value out of a results path.

    Args:
        dir_path: Path string expected to contain ``xi<digits>``.

    Returns:
        The digits following the first ``xi`` occurrence, as an int, or None
        if the path contains no such substring.
    """
    found = re.search(r'xi(\d+)', dir_path)
    return int(found.group(1)) if found else None

# Known xi values and the millisecond label each one gets in table headers.
_XI_TO_MS = {
    694: 50,    # 0.05s
    1498: 100,  # 0.1s
    2302: 150,  # 0.15s
    3107: 200,  # 0.2s
    7934: 500,  # 0.5s
}

def map_xi_to_ms(xi):
    """Return the millisecond header label for ``xi``, or None if ``xi`` is unknown."""
    return _XI_TO_MS.get(xi)

def load_data(base_dir="./results"):
    """Load every latency-comparison CSV under ``base_dir``, keyed by xi.

    Looks for
    ``<base_dir>/latency_200ms_comparison_xi*Q*forced*/latency_200ms_comparison.csv``
    and parses xi from each containing directory with ``get_xi_from_directory``.

    Args:
        base_dir: Root directory holding the per-run result folders. Defaults
            to ``./results``, relative to the current working directory.

    Returns:
        dict mapping xi (int) -> DataFrame read from that run's CSV.
        Directories whose path contains no ``xi<digits>`` are skipped. If
        several directories share an xi, the one that sorts last by path is
        kept and a warning is printed.
    """
    pattern = os.path.join(
        base_dir,
        "latency_200ms_comparison_xi*Q*forced*",
        "latency_200ms_comparison.csv",
    )
    results = {}
    # glob's order is filesystem-dependent; sort so that the choice among runs
    # sharing an xi is reproducible instead of arbitrary.
    for file_path in sorted(glob.glob(pattern)):
        xi = get_xi_from_directory(os.path.dirname(file_path))
        if xi is None:
            continue
        if xi in results:
            print(f"Warning: multiple result files for xi={xi}; using {file_path}")
        results[xi] = pd.read_csv(file_path)
    return results

def generate_improvement_table(results, comparison_type='vanilla'):
    """Build a per-capacity table of T-LRU's relative improvement over one baseline.

    Args:
        results: Mapping of xi -> DataFrame, as returned by ``load_data``. Each
            frame must contain 'Cache Capacity', 'T-LRU Turns > 200ms' and the
            baseline column selected by ``comparison_type``.
        comparison_type: 'vanilla' compares against 'Vanilla LRU Turns > 200ms';
            'threshold' compares against 'Threshold LRU Turns > 200ms'.

    Returns:
        DataFrame with a 'Capacity' column (the sorted union of all capacities)
        plus one column per xi (ascending) holding the percentage improvement.
        A cell is NaN where that xi has no row for the capacity. Columns are
        labelled ``$\\xi$ = <ms>ms`` when ``map_xi_to_ms`` knows the xi, and
        ``$\\xi$ = <xi>`` otherwise.

    Raises:
        ValueError: if ``comparison_type`` is neither 'vanilla' nor 'threshold'.
    """
    baseline_columns = {
        'vanilla': 'Vanilla LRU Turns > 200ms',
        'threshold': 'Threshold LRU Turns > 200ms',
    }
    try:
        comparison_col = baseline_columns[comparison_type]
    except KeyError:
        # Previously any unknown value silently fell back to the threshold baseline.
        raise ValueError(
            f"comparison_type must be 'vanilla' or 'threshold', got {comparison_type!r}"
        ) from None

    all_xi = sorted(results)
    all_capacities = sorted(
        {cap for df in results.values() for cap in df['Cache Capacity'].tolist()}
    )

    # One header label per xi. Unmapped xi fall back to the raw value so that
    # several of them cannot collapse into a single "$\xi$ = Nonems" column.
    col_names = {}
    for xi in all_xi:
        ms_value = map_xi_to_ms(xi)
        col_names[xi] = f"$\\xi$ = {ms_value}ms" if ms_value is not None else f"$\\xi$ = {xi}"

    improvement_data = []
    for capacity in all_capacities:
        row_data = {'Capacity': capacity}
        for xi in all_xi:
            df = results[xi]
            capacity_row = df[df['Cache Capacity'] == capacity]
            if capacity_row.empty:
                row_data[col_names[xi]] = np.nan
                continue
            tlru_value = capacity_row['T-LRU Turns > 200ms'].values[0]
            comparison_value = capacity_row[comparison_col].values[0]
            row_data[col_names[xi]] = calculate_relative_improvement(comparison_value, tlru_value)
        improvement_data.append(row_data)

    return pd.DataFrame(improvement_data)

def generate_merged_improvement_table(results):
    """Build a per-capacity table of T-LRU's improvement over both baselines.

    For every xi (ascending) two adjacent columns are produced:
    ``"<label> (LRU)"`` (vs 'Vanilla LRU Turns > 200ms') and
    ``"<label> (Thre-LRU)"`` (vs 'Threshold LRU Turns > 200ms'). Here
    ``<label>`` is ``$\\xi$ = <ms>ms`` for xi values known to ``map_xi_to_ms``
    and ``$\\xi$ = <xi>`` otherwise.

    Args:
        results: Mapping of xi -> DataFrame, as returned by ``load_data``.

    Returns:
        DataFrame with a sorted 'Capacity' column plus the column pairs above.
        Cells are NaN where an xi has no row for a capacity.
    """
    all_xi = sorted(results)
    all_capacities = sorted(
        {cap for df in results.values() for cap in df['Cache Capacity'].tolist()}
    )

    # Precompute each xi's column pair. Unmapped xi fall back to the raw value
    # so they cannot collide in a shared "$\xi$ = Nonems (...)" column.
    column_pairs = {}
    for xi in all_xi:
        ms_value = map_xi_to_ms(xi)
        base = f"$\\xi$ = {ms_value}ms" if ms_value is not None else f"$\\xi$ = {xi}"
        column_pairs[xi] = (f"{base} (LRU)", f"{base} (Thre-LRU)")

    improvement_data = []
    for capacity in all_capacities:
        row_data = {'Capacity': capacity}
        for xi in all_xi:
            vanilla_col_name, threshold_col_name = column_pairs[xi]
            df = results[xi]
            capacity_row = df[df['Cache Capacity'] == capacity]
            if capacity_row.empty:
                row_data[vanilla_col_name] = np.nan
                row_data[threshold_col_name] = np.nan
                continue
            tlru_value = capacity_row['T-LRU Turns > 200ms'].values[0]
            vanilla_value = capacity_row['Vanilla LRU Turns > 200ms'].values[0]
            threshold_value = capacity_row['Threshold LRU Turns > 200ms'].values[0]
            row_data[vanilla_col_name] = calculate_relative_improvement(vanilla_value, tlru_value)
            row_data[threshold_col_name] = calculate_relative_improvement(threshold_value, tlru_value)
        improvement_data.append(row_data)

    return pd.DataFrame(improvement_data)

def format_improvement_value(value):
    """Render an improvement percentage as a coloured LaTeX table cell.

    Positive values get a green gradient (darker = larger improvement). Exactly
    zero is white, and negative values (T-LRU worse than the baseline) are
    light red. The consuming document needs ``\\cellcolor`` with the HTML
    colour model (e.g. via xcolor's ``table`` option).

    Args:
        value: Improvement in percent; NaN/None means "no data".

    Returns:
        ``"\\cellcolor[HTML]{<hex>} <value, 1 decimal>\\%"``, or ``""`` for NaN/None.
    """
    if pd.isna(value):
        return ""

    # (exclusive lower bound in %, hex colour), strongest improvement first.
    green_scale = (
        (40, "57bb8a"),  # (40, inf): dark green
        (30, "85cdaa"),  # (30, 40]:  medium green
        (20, "96d4b6"),  # (20, 30]:  light-medium green
        (10, "b8e2ce"),  # (10, 20]:  light green
        (5, "c9e9d9"),   # (5, 10]:   very light green
        (0, "dcf1e6"),   # (0, 5]:    extremely light green
    )
    for threshold, shade in green_scale:
        if value > threshold:
            color = shade
            break
    else:
        # Not a positive improvement: white for no change, light red for a regression.
        color = "ffffff" if value == 0 else "ffd6d6"

    return f"\\cellcolor[HTML]{{{color}}} {value:.1f}\\%"

def generate_latex_table(df, caption, label):
    """Render a single-level improvement DataFrame as a LaTeX ``table``.

    Every column except 'Capacity' is passed through ``format_improvement_value``.
    The first column name is used as a plain header, and the remaining headers
    are wrapped in ``\\multicolumn{1}{c}{...}``. All columns are right-aligned,
    and the table uses booktabs rules (``\\toprule``, ``\\cmidrule``,
    ``\\bottomrule``).

    Args:
        df: DataFrame whose first column is expected to be 'Capacity'.
        caption: Caption text, inserted verbatim.
        label: LaTeX label, inserted verbatim after ``\\end{tabular}``.

    Returns:
        The LaTeX source as a string, without a trailing newline.
    """
    formatted = df.copy()
    for column in formatted.columns:
        if column == 'Capacity':
            continue
        formatted[column] = formatted[column].apply(format_improvement_value)

    columns = formatted.columns
    width = len(columns)
    header = " &".join(
        [columns[0]] + ["\\multicolumn{1}{c}{" + name + "}" for name in columns[1:]]
    )
    body_rows = [
        " & ".join([str(record['Capacity'])] + [str(record[name]) for name in columns[1:]])
        + " \\\\\n"
        for _, record in formatted.iterrows()
    ]

    pieces = [
        "\\begin{table}[!htp]\\centering\n",
        f"\\caption{{{caption}}}\n",
        "\\scriptsize\n",
        f"\\begin{{tabular}}{{{'r' * width}}}\\toprule\n",
        f"{header} \\\\\\cmidrule{{1-{width}}}\n",
        *body_rows,
        "\\bottomrule\n",
        f"\\end{{tabular}} \\label{{{label}}}\n",
        "\\end{table}",
    ]
    return "".join(pieces)

def generate_merged_latex_table(df, caption, label):
    """Render the merged improvement DataFrame as a LaTeX table with grouped headers.

    Every non-'Capacity' column is expected to be named ``"<group> (<baseline>)"``,
    e.g. ``"$\\xi$ = 50ms (LRU)"``. Consecutive columns sharing a ``<group>`` are
    placed under one ``\\multicolumn`` header, and the ``<baseline>`` parts form
    the second header row. Spans are derived from the actual columns, so the
    header always has as many cells as the column spec and the data rows, even
    if a group is incomplete or there are no value columns. Cell values are
    formatted with ``format_improvement_value``.

    Args:
        df: DataFrame with a 'Capacity' column followed by the improvement columns.
        caption: Caption text, inserted verbatim.
        label: LaTeX label, inserted verbatim after ``\\end{tabular}``.

    Returns:
        The complete ``table`` environment as a string, without a trailing newline.
    """
    from itertools import groupby

    latex_df = df.copy()
    for col in latex_df.columns:
        if col != 'Capacity':
            latex_df[col] = latex_df[col].apply(format_improvement_value)

    value_cols = list(latex_df.columns[1:])

    # Split "<group> (<baseline>)" into its parts; a column without "(" forms a
    # group of its own with an empty baseline label.
    parsed = []
    for col in value_cols:
        group, sep, rest = col.partition("(")
        parsed.append((group.strip(), rest.rstrip(")").strip() if sep else ""))

    group_headers = [
        f"\\multicolumn{{{sum(1 for _ in members)}}}{{c}}{{{group}}}"
        for group, members in groupby(parsed, key=lambda item: item[0])
    ]
    sub_headers = [baseline for _, baseline in parsed]
    num_cols = 1 + len(value_cols)

    lines = [
        "\\begin{table}[!htp]\\centering",
        f"\\caption{{{caption}}}",
        "\\scriptsize",
        f"\\begin{{tabular}}{{{'r' * num_cols}}}\\toprule",
        # First header row: one spanning cell per xi group.
        " &".join(["Capacity"] + group_headers) + " \\\\",
        # Second header row: the baseline name above each value column.
        " &".join([""] + sub_headers) + f" \\\\\\cmidrule{{1-{num_cols}}}",
    ]
    for _, row in latex_df.iterrows():
        cells = [str(row['Capacity'])] + [str(row[col]) for col in value_cols]
        lines.append(" & ".join(cells) + " \\\\")
    lines += [
        "\\bottomrule",
        f"\\end{{tabular}} \\label{{{label}}}",
        "\\end{table}",
    ]
    return "\n".join(lines)

def main():
    """Build the T-LRU improvement tables and write them to the working directory.

    Loads the comparison CSVs via ``load_data``. If none are found, prints a
    hint and returns without writing anything. Otherwise it writes three LaTeX
    tables and three CSV tables (vanilla, threshold, merged), then prints the
    merged table to stdout.
    """
    results = load_data()
    if not results:
        print("No data files found. Please run latency_threshold_caching.py first.")
        return

    vanilla_df = generate_improvement_table(results, 'vanilla')
    threshold_df = generate_improvement_table(results, 'threshold')
    merged_df = generate_merged_improvement_table(results)

    tex_outputs = [
        (
            "tlru_vs_vanilla_improvement_table.tex",
            generate_latex_table(
                vanilla_df,
                "Relative improvement of T-LRU over Vanilla LRU: \\% reduction in requests with latency > 200ms",
                "tab:improvement-tlru-lru",
            ),
        ),
        (
            "tlru_vs_threshold_improvement_table.tex",
            generate_latex_table(
                threshold_df,
                "Relative improvement of T-LRU over Threshold LRU: \\% reduction in requests with latency > 200ms",
                "tab:improvement-tlru-thresholdlru",
            ),
        ),
        (
            "tlru_merged_improvement_table.tex",
            generate_merged_latex_table(
                merged_df,
                "Relative improvement of T-LRU: \\% reduction in requests with latency > 200ms",
                "tab:merged-improvement-tlru",
            ),
        ),
    ]
    for out_path, tex_source in tex_outputs:
        with open(out_path, "w") as handle:
            handle.write(tex_source)
    print("LaTeX tables generated successfully!")

    # Plain CSV copies for quick inspection outside LaTeX.
    csv_outputs = [
        ("tlru_vs_vanilla_improvement_table.csv", vanilla_df),
        ("tlru_vs_threshold_improvement_table.csv", threshold_df),
        ("tlru_merged_improvement_table.csv", merged_df),
    ]
    for out_path, frame in csv_outputs:
        frame.to_csv(out_path, index=False)
    print("CSV tables generated successfully!")

    print("\nMerged Improvement Table:")
    print(merged_df.to_string())

if __name__ == "__main__":
    # Generate the tables only when run as a script, never on import.
    main()